{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VAE Analysis - Faces dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from scipy.stats import norm\n",
    "import pandas as pd\n",
    "\n",
    "from models.VAE import VariationalAutoencoder\n",
    "from utils.loaders import load_model, ImageLabelLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run params\n",
    "section = 'vae'\n",
    "run_id = '0001'\n",
    "data_name = 'faces'\n",
    "RUN_FOLDER = 'run/{}/'.format(section)\n",
    "RUN_FOLDER += '_'.join([run_id, data_name])\n",
    "\n",
    "\n",
    "DATA_FOLDER = './data/celeb/'\n",
    "IMAGE_FOLDER = './data/celeb/img_align_celeba/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = (128,128,3)\n",
    "\n",
    "att = pd.read_csv(os.path.join(DATA_FOLDER, 'list_attr_celeba.csv'))\n",
    "\n",
    "imageLoader = ImageLabelLoader(IMAGE_FOLDER, INPUT_DIM[:2])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "att.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = load_model(VariationalAutoencoder, RUN_FOLDER)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reconstructing faces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_to_show = 10\n",
    "\n",
    "data_flow_generic = imageLoader.build(att, n_to_show)\n",
    "\n",
    "example_batch = next(data_flow_generic)\n",
    "example_images = example_batch[0]\n",
    "\n",
    "z_points = vae.encoder.predict(example_images)\n",
    "\n",
    "reconst_images = vae.decoder.predict(z_points)\n",
    "\n",
    "fig = plt.figure(figsize=(15, 3))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "for i in range(n_to_show):\n",
    "    img = example_images[i].squeeze()\n",
    "    sub = fig.add_subplot(2, n_to_show, i+1)\n",
    "    sub.axis('off')        \n",
    "    sub.imshow(img)\n",
    "\n",
    "for i in range(n_to_show):\n",
    "    img = reconst_images[i].squeeze()\n",
    "    sub = fig.add_subplot(2, n_to_show, i+n_to_show+1)\n",
    "    sub.axis('off')\n",
    "    sub.imshow(img)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Latent space distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_test = vae.encoder.predict_generator(data_flow_generic, steps = 20, verbose = 1)\n",
    "\n",
    "x = np.linspace(-3, 3, 100)\n",
    "\n",
    "fig = plt.figure(figsize=(20, 20))\n",
    "fig.subplots_adjust(hspace=0.6, wspace=0.4)\n",
    "\n",
    "for i in range(50):\n",
    "    ax = fig.add_subplot(5, 10, i+1)\n",
    "    ax.hist(z_test[:,i], density=True, bins = 20)\n",
    "    ax.axis('off')\n",
    "    ax.text(0.5, -0.35, str(i), fontsize=10, ha='center', transform=ax.transAxes)\n",
    "    ax.plot(x,norm.pdf(x))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Newly generated faces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_to_show = 30\n",
    "\n",
    "znew = np.random.normal(size = (n_to_show,vae.z_dim))\n",
    "\n",
    "reconst = vae.decoder.predict(np.array(znew))\n",
    "\n",
    "fig = plt.figure(figsize=(18, 5))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "for i in range(n_to_show):\n",
    "    ax = fig.add_subplot(3, 10, i+1)\n",
    "    ax.imshow(reconst[i, :,:,:])\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_vector_from_label(label, batch_size):\n",
    "\n",
    "    data_flow_label = imageLoader.build(att, batch_size, label = label)\n",
    "\n",
    "    origin = np.zeros(shape = vae.z_dim, dtype = 'float32')\n",
    "    current_sum_POS = np.zeros(shape = vae.z_dim, dtype = 'float32')\n",
    "    current_n_POS = 0\n",
    "    current_mean_POS = np.zeros(shape = vae.z_dim, dtype = 'float32')\n",
    "\n",
    "    current_sum_NEG = np.zeros(shape = vae.z_dim, dtype = 'float32')\n",
    "    current_n_NEG = 0\n",
    "    current_mean_NEG = np.zeros(shape = vae.z_dim, dtype = 'float32')\n",
    "\n",
    "    current_vector = np.zeros(shape = vae.z_dim, dtype = 'float32')\n",
    "    current_dist = 0\n",
    "\n",
    "    print('label: ' + label)\n",
    "    print('images : POS move : NEG move :distance : 𝛥 distance')\n",
    "    while(current_n_POS < 10000):\n",
    "\n",
    "        batch = next(data_flow_label)\n",
    "        im = batch[0]\n",
    "        attribute = batch[1]\n",
    "\n",
    "        z = vae.encoder.predict(np.array(im))\n",
    "\n",
    "        z_POS = z[attribute==1]\n",
    "        z_NEG = z[attribute==-1]\n",
    "\n",
    "        if len(z_POS) > 0:\n",
    "            current_sum_POS = current_sum_POS + np.sum(z_POS, axis = 0)\n",
    "            current_n_POS += len(z_POS)\n",
    "            new_mean_POS = current_sum_POS / current_n_POS\n",
    "            movement_POS = np.linalg.norm(new_mean_POS-current_mean_POS)\n",
    "\n",
    "        if len(z_NEG) > 0: \n",
    "            current_sum_NEG = current_sum_NEG + np.sum(z_NEG, axis = 0)\n",
    "            current_n_NEG += len(z_NEG)\n",
    "            new_mean_NEG = current_sum_NEG / current_n_NEG\n",
    "            movement_NEG = np.linalg.norm(new_mean_NEG-current_mean_NEG)\n",
    "\n",
    "        current_vector = new_mean_POS-new_mean_NEG\n",
    "        new_dist = np.linalg.norm(current_vector)\n",
    "        dist_change = new_dist - current_dist\n",
    "\n",
    "\n",
    "        print(str(current_n_POS)\n",
    "              + '    : ' + str(np.round(movement_POS,3))\n",
    "              + '    : ' + str(np.round(movement_NEG,3))\n",
    "              + '    : ' + str(np.round(new_dist,3))\n",
    "              + '    : ' + str(np.round(dist_change,3))\n",
    "             )\n",
    "\n",
    "        current_mean_POS = np.copy(new_mean_POS)\n",
    "        current_mean_NEG = np.copy(new_mean_NEG)\n",
    "        current_dist = np.copy(new_dist)\n",
    "\n",
    "        if np.sum([movement_POS, movement_NEG]) < 0.08:\n",
    "            current_vector = current_vector / current_dist\n",
    "            print('Found the ' + label + ' vector')\n",
    "            break\n",
    "\n",
    "    return current_vector   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_vector_to_images(feature_vec):\n",
    "\n",
    "    n_to_show = 5\n",
    "    factors = [-4,-3,-2,-1,0,1,2,3,4]\n",
    "\n",
    "    example_batch = next(data_flow_generic)\n",
    "    example_images = example_batch[0]\n",
    "    example_labels = example_batch[1]\n",
    "\n",
    "    z_points = vae.encoder.predict(example_images)\n",
    "\n",
    "    fig = plt.figure(figsize=(18, 10))\n",
    "\n",
    "    counter = 1\n",
    "\n",
    "    for i in range(n_to_show):\n",
    "\n",
    "        img = example_images[i].squeeze()\n",
    "        sub = fig.add_subplot(n_to_show, len(factors) + 1, counter)\n",
    "        sub.axis('off')        \n",
    "        sub.imshow(img)\n",
    "\n",
    "        counter += 1\n",
    "\n",
    "        for factor in factors:\n",
    "\n",
    "            changed_z_point = z_points[i] + feature_vec * factor\n",
    "            changed_image = vae.decoder.predict(np.array([changed_z_point]))[0]\n",
    "\n",
    "            img = changed_image.squeeze()\n",
    "            sub = fig.add_subplot(n_to_show, len(factors) + 1, counter)\n",
    "            sub.axis('off')\n",
    "            sub.imshow(img)\n",
    "\n",
    "            counter += 1\n",
    "\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 500\n",
    "attractive_vec = get_vector_from_label('Attractive', BATCH_SIZE)\n",
    "mouth_open_vec = get_vector_from_label('Mouth_Slightly_Open', BATCH_SIZE)\n",
    "smiling_vec = get_vector_from_label('Smiling', BATCH_SIZE)\n",
    "lipstick_vec = get_vector_from_label('Wearing_Lipstick', BATCH_SIZE)\n",
    "young_vec = get_vector_from_label('High_Cheekbones', BATCH_SIZE)\n",
    "male_vec = get_vector_from_label('Male', BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "eyeglasses_vec = get_vector_from_label('Eyeglasses', BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blonde_vec = get_vector_from_label('Blond_Hair', BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# print('Attractive Vector')\n",
    "# add_vector_to_images(attractive_vec)\n",
    "\n",
    "# print('Mouth Open Vector')\n",
    "# add_vector_to_images(mouth_open_vec)\n",
    "\n",
    "# print('Smiling Vector')\n",
    "# add_vector_to_images(smiling_vec)\n",
    "\n",
    "# print('Lipstick Vector')\n",
    "# add_vector_to_images(lipstick_vec)\n",
    "\n",
    "# print('Young Vector')\n",
    "# add_vector_to_images(young_vec)\n",
    "\n",
    "# print('Male Vector')\n",
    "# add_vector_to_images(male_vec)\n",
    "\n",
    "print('Eyeglasses Vector')\n",
    "add_vector_to_images(eyeglasses_vec)\n",
    "\n",
    "# print('Blond Vector')\n",
    "# add_vector_to_images(blonde_vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def morph_faces(start_image_file, end_image_file):\n",
    "\n",
    "    factors = np.arange(0,1,0.1)\n",
    "\n",
    "    att_specific = att[att['image_id'].isin([start_image_file, end_image_file])]\n",
    "    att_specific = att_specific.reset_index()\n",
    "    data_flow_label = imageLoader.build(att_specific, 2)\n",
    "\n",
    "    example_batch = next(data_flow_label)\n",
    "    example_images = example_batch[0]\n",
    "    example_labels = example_batch[1]\n",
    "\n",
    "    z_points = vae.encoder.predict(example_images)\n",
    "\n",
    "\n",
    "    fig = plt.figure(figsize=(18, 8))\n",
    "\n",
    "    counter = 1\n",
    "\n",
    "    img = example_images[0].squeeze()\n",
    "    sub = fig.add_subplot(1, len(factors)+2, counter)\n",
    "    sub.axis('off')        \n",
    "    sub.imshow(img)\n",
    "\n",
    "    counter+=1\n",
    "\n",
    "\n",
    "    for factor in factors:\n",
    "\n",
    "        changed_z_point = z_points[0] * (1-factor) + z_points[1]  * factor\n",
    "        changed_image = vae.decoder.predict(np.array([changed_z_point]))[0]\n",
    "\n",
    "        img = changed_image.squeeze()\n",
    "        sub = fig.add_subplot(1, len(factors)+2, counter)\n",
    "        sub.axis('off')\n",
    "        sub.imshow(img)\n",
    "\n",
    "        counter += 1\n",
    "\n",
    "    img = example_images[1].squeeze()\n",
    "    sub = fig.add_subplot(1, len(factors)+2, counter)\n",
    "    sub.axis('off')        \n",
    "    sub.imshow(img)\n",
    "\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_image_file = '000238.jpg' \n",
    "end_image_file = '000193.jpg' #glasses\n",
    "\n",
    "morph_faces(start_image_file, end_image_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_image_file = '000112.jpg'\n",
    "end_image_file = '000258.jpg'\n",
    "\n",
    "morph_faces(start_image_file, end_image_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_image_file = '000230.jpg'\n",
    "end_image_file = '000712.jpg'\n",
    "\n",
    "morph_faces(start_image_file, end_image_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdl_code",
   "language": "python",
   "name": "gdl_code"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
